{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "12153d60",
   "metadata": {},
   "source": [
    "This notebook does the same experiments as `03_Mitigation.ipynb`, with fewer generation methods and a larger pool of hypotheses (50 instead of 10)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa5262f6",
   "metadata": {},
   "source": [
    "# Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd5bfe1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path to the translation model and its vocabulary, in order to compute ALTI correctly. \n",
    "MODEL_DIR = '../model'\n",
    "DATA_DIR = '../model/wmt18_de-en'\n",
    "LASER_DIR = '../laser'\n",
    "USE_GPU = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf323269",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# This is for comet to behave. The variable is set here, before torch is imported below.\n",
    "# A first assignment of ':16:8' used to precede this line and was immediately overwritten,\n",
    "# so only the effective value is kept.\n",
    "os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60466e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f80f40e",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.use_deterministic_algorithms(False) # otherwise, comet complains\n",
    "#!pip install unbabel-comet==1.1.2 --use-feature=2020-resolver\n",
    "import comet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f99885a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairseq.models.transformer import TransformerModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5d4d6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stopes.eval.alti.wrappers.transformer_wrapper import FairseqTransformerHub\n",
    "from stopes.eval.alti.alti_metrics.alti_metrics_utils import compute_alti_metrics, compute_alti_nllb, get_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea78bfb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stopes.modules.preprocess.laser_sentence_encoder import SentenceEncoder, spm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0961ef43",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fbc7943",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "def cleanup():\n",
    "    \"\"\"Run the garbage collector, then empty the CUDA cache when CUDA is available.\"\"\"\n",
    "    gc.collect()\n",
    "    if not torch.cuda.is_available():\n",
    "        return\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d8dd374",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm, trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34499932",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e1a59b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sacrebleu import CHRF"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a26cee22",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f99a9ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "gt = pd.read_csv('../annotated_data/guerreiro2022_corpus_w_annotations.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe9e5cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "gt['any_mistake'] = 1 - gt.correctness\n",
    "gt['any_detached'] = gt[['strong-unsupport', 'full-unsupport']].max(1)\n",
    "gt['repeat_or_detached'] = gt[['repetitions', 'strong-unsupport', 'full-unsupport']].max(1)\n",
    "gt['other_errors'] = gt['any_mistake']-gt['named-entities']-gt['omission']-gt['repeat_or_detached']\n",
    "gt['error_class'] = gt['any_detached'] + gt['full-unsupport'] + gt['any_mistake']\n",
    "gt['error_class'].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18253fe6",
   "metadata": {},
   "source": [
    "Sample 400 source texts (100 from each `error_class` group)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "880e9a4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "smpl = gt.groupby('error_class').sample(100, random_state=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8db2f27",
   "metadata": {},
   "source": [
    "# Creating the translations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03cf74da",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_BEAM_SIZE = 10\n",
    "MAX_HYP_NUMBER = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32fa821a",
   "metadata": {},
   "outputs": [],
   "source": [
    "de2en = TransformerModel.from_pretrained(\n",
    "    MODEL_DIR,\n",
    "    checkpoint_file='checkpoint_best.pt',\n",
    "    data_name_or_path=DATA_DIR,\n",
    "    bpe='sentencepiece', \n",
    "    sentencepiece_model=MODEL_DIR + '/sentencepiece.joint.bpe.model'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3e21e46",
   "metadata": {},
   "outputs": [],
   "source": [
    "de2en.cuda();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9998579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diverse translations of the data sample, key: list of lists of translation hypotheses.\n",
    "smpl_diverse = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c8844cc",
   "metadata": {},
   "source": [
    "### Baseline translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "720f11f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_tran = [de2en.translate(t, beam=5) for t in tqdm(smpl.src)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "220f2cd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = 'default'\n",
    "smpl_diverse[key] = [[mt] for mt in new_tran]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fd965a0",
   "metadata": {},
   "source": [
    "### Random sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3080940",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = 'sampling_p08'\n",
    "smpl_diverse[key] = []\n",
    "for text in tqdm(smpl.src):\n",
    "    options = []\n",
    "    enc = [de2en.encode(text)]\n",
    "    for _ in range(MAX_HYP_NUMBER):\n",
    "        batched_hypos = de2en.generate(enc, sampling=True, sampling_topp=0.8, beam=1)\n",
    "        out_texts = [de2en.decode(h['tokens']) for h in batched_hypos[0]]\n",
    "        options.append(out_texts[0])\n",
    "    smpl_diverse[key].append(options)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99b86d9e",
   "metadata": {},
   "source": [
    "### Beam search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c474156",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = 'beam_search'\n",
    "smpl_diverse[key] = []\n",
    "for text in tqdm(smpl.src):\n",
    "    # one generate call per source text; every hypothesis it returns is kept\n",
    "    enc = [de2en.encode(text)]\n",
    "    batched_hypos = de2en.generate(enc, beam=MAX_HYP_NUMBER)\n",
    "    out_texts = [de2en.decode(h['tokens']) for h in batched_hypos[0]]\n",
    "    smpl_diverse[key].append(out_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffab7486",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = 'beam_diversity_1'\n",
    "smpl_diverse[key] = []\n",
    "for text in tqdm(smpl.src):\n",
    "    # same as plain beam search, with diversity_rate=1.0 passed to generate\n",
    "    enc = [de2en.encode(text)]\n",
    "    batched_hypos = de2en.generate(enc, beam=MAX_HYP_NUMBER, diversity_rate=1.0)\n",
    "    out_texts = [de2en.decode(h['tokens']) for h in batched_hypos[0]]\n",
    "    smpl_diverse[key].append(out_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9dd6af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = 'beam_dbs_3'\n",
    "smpl_diverse[key] = []\n",
    "for text in tqdm(smpl.src):\n",
    "    # diverse beam search: as many beam groups as beams, with strength 3\n",
    "    enc = [de2en.encode(text)]\n",
    "    batched_hypos = de2en.generate(enc, beam=MAX_HYP_NUMBER, diverse_beam_groups=MAX_HYP_NUMBER, diverse_beam_strength=3)\n",
    "    out_texts = [de2en.decode(h['tokens']) for h in batched_hypos[0]]\n",
    "    smpl_diverse[key].append(out_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "552fc569",
   "metadata": {},
   "source": [
    "### Dropout methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c19e0986",
   "metadata": {},
   "outputs": [],
   "source": [
    "for mn, m in de2en.named_modules():  # an easy way to randomize the model!\n",
    "    if 'dropout' in mn:\n",
    "        m.apply_during_inference = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f894bd4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = 'beam_dropout'\n",
    "smpl_diverse[key] = []\n",
    "for text in tqdm(smpl.src):\n",
    "    options = []\n",
    "    enc = [de2en.encode(text)]\n",
    "    for _ in range(MAX_HYP_NUMBER):\n",
    "        batched_hypos = de2en.generate(enc, beam=MAX_BEAM_SIZE, retain_dropout=True)\n",
    "        out_texts = [de2en.decode(h['tokens']) for h in batched_hypos[0]]\n",
    "        options.append(out_texts[0])\n",
    "    smpl_diverse[key].append(options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db03865d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for mn, m in de2en.named_modules():\n",
    "    if 'dropout' in mn:\n",
    "        m.apply_during_inference = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d29198f",
   "metadata": {},
   "source": [
    "# Scoring the hypotheses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfc7be4c",
   "metadata": {},
   "source": [
    "Here is the mean quality of outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44b815f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "hypotheses_scores = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c05b372a",
   "metadata": {},
   "source": [
    "### By LaBSE scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d0ecc09",
   "metadata": {},
   "outputs": [],
   "source": [
    "labse = SentenceTransformer('sentence-transformers/LaBSE')\n",
    "labse.cuda();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77a1f0eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_pair(src, trg):\n",
    "    \"\"\"Return the dot product of the LaBSE embeddings of `src` and `trg`.\"\"\"\n",
    "    src_emb, trg_emb = labse.encode([src, trg], show_progress_bar=False)\n",
    "    return src_emb.dot(trg_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ae3384",
   "metadata": {},
   "outputs": [],
   "source": [
    "def argmax(values, criterion):\n",
    "    \"\"\"Return the element of `values` whose `criterion(element)` is highest.\n",
    "\n",
    "    The earliest element wins ties. Returns None when `values` is empty or when no\n",
    "    score compares greater than negative infinity.\n",
    "    \"\"\"\n",
    "    # np.inf rather than the np.infty alias, which NumPy 2.0 removed\n",
    "    best = -np.inf\n",
    "    candidate = None\n",
    "    for v in values:\n",
    "        score = criterion(v)\n",
    "        if score > best:\n",
    "            best = score\n",
    "            candidate = v\n",
    "    return candidate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47e017d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "hypotheses_scores['LABSE'] = {\n",
    "    k: [[score_pair(x, smpl.iloc[i].src) for x in hyps] for i, hyps in enumerate(vs)]\n",
    "    for k, vs in tqdm(smpl_diverse.items(), total=len(smpl_diverse))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f058c1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "labse.to('cpu')\n",
    "cleanup();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f6f8f1b",
   "metadata": {},
   "source": [
    "### By COMET-QE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b289e757",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = comet.download_model(\"wmt20-comet-qe-da-v2\")\n",
    "model = comet.load_from_checkpoint(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c637c01c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def score_pair(src, trg):\n",
    "    seg_scores, sys_score = model.predict([{'src': src, 'mt': trg}], batch_size=8, gpus=0)\n",
    "    # with 0 gpus, this is actually faster\n",
    "    return seg_scores[0]\n",
    "\n",
    "print(score_pair('hallo Welt', 'hello world'))\n",
    "print(score_pair('hello world', 'hallo Welt'))\n",
    "print(score_pair('hallo Welt', 'halo over my head'))\n",
    "print(score_pair('halo over my head', 'hallo Welt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f03336af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_pairs(src, trg, batch_size=8, gpus=0):\n",
    "    \"\"\"Score aligned (source, translation) pairs with the currently loaded COMET `model`.\n",
    "\n",
    "    Only the segment-level scores are returned; the system-level score is discarded.\n",
    "    The default `gpus=0` follows the original note that with 0 gpus prediction was\n",
    "    actually faster.\n",
    "    \"\"\"\n",
    "    samples = [{'src': s, 'mt': t} for s, t in zip(src, trg)]\n",
    "    seg_scores, _ = model.predict(samples, batch_size=batch_size, gpus=gpus)\n",
    "    return seg_scores\n",
    "\n",
    "\n",
    "def get_scores_batched(hyp_sets, sources, **kwargs):\n",
    "    \"\"\"Score every hypothesis of every source with a single `score_pairs` call.\n",
    "\n",
    "    `hyp_sets` and `sources` are iterated in parallel; each hypothesis is paired with\n",
    "    the source of its set. Extra keyword arguments are forwarded to `score_pairs`.\n",
    "    Returns a list with one slice of scores per hypothesis set, in input order.\n",
    "    \"\"\"\n",
    "    srcs, tgts, spans = [], [], []\n",
    "    for hyp_set, src in zip(hyp_sets, sources):\n",
    "        start = len(srcs)\n",
    "        for mt in hyp_set:\n",
    "            srcs.append(src)\n",
    "            tgts.append(mt)\n",
    "        # remember which stretch of the flat list belongs to this set\n",
    "        spans.append((start, len(srcs)))\n",
    "    scores = score_pairs(srcs, tgts, **kwargs)\n",
    "    return [scores[start:end] for start, end in spans]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "484c13f4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "hypotheses_scores['COMET-QE'] = {\n",
    "    k: get_scores_batched(vs, smpl.src.tolist(), gpus=1)\n",
    "    for k, vs in tqdm(smpl_diverse.items(), total=len(smpl_diverse))\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27bc64df",
   "metadata": {},
   "source": [
    "### By LASER2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66bb8289",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SentencePiece tokenizer and sentence encoder, both loaded from files under LASER_DIR.\n",
    "spm_tokenizer = spm.SentencePieceProcessor()\n",
    "spm_tokenizer.Load(LASER_DIR + '/laser2.spm')\n",
    "\n",
    "# NOTE(review): spm_vocab is given the same laser2.pt checkpoint path as the first argument;\n",
    "# a vocabulary file may have been intended here - TODO confirm against SentenceEncoder.\n",
    "laser_encoder = SentenceEncoder(\n",
    "    LASER_DIR + '/laser2.pt',\n",
    "    max_sentences=None,\n",
    "    max_tokens=None,\n",
    "    spm_vocab=LASER_DIR + '/laser2.pt',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a8d99ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_sents(sents):\n",
    "    \"\"\"Embed sentences with the LASER encoder and scale each row to unit L2 norm.\"\"\"\n",
    "    pieces = []\n",
    "    for sent in sents:\n",
    "        # the encoder is fed space-joined SentencePiece pieces\n",
    "        pieces.append(\" \".join(spm_tokenizer.EncodeAsPieces(sent)))\n",
    "    emb = laser_encoder.encode_sentences(pieces)\n",
    "    norms = (emb**2).sum(1, keepdims=True) ** 0.5\n",
    "    return emb / norms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2930761d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_pair(src, trg):\n",
    "    embs = encode_sents([src, trg])\n",
    "    return embs[0].dot(embs[1])\n",
    "\n",
    "print(score_pair('hallo Welt', 'hello world'))\n",
    "print(score_pair('hallo Welt', 'halo over my hed'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b881041f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hypotheses_scores['LASER2'] = {\n",
    "    k: [[score_pair(x, smpl.iloc[i].src) for x in hyps] for i, hyps in enumerate(vs)]\n",
    "    for k, vs in tqdm(smpl_diverse.items(), total=len(smpl_diverse))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b3c9652",
   "metadata": {},
   "outputs": [],
   "source": [
    "laser_encoder.encoder.to('cpu');\n",
    "cleanup();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf354e79",
   "metadata": {},
   "source": [
    "### By ALTI+ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2718f8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hub = FairseqTransformerHub(cfg=de2en.cfg, models=de2en.models, task=de2en.task)\n",
    "hub.cuda();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42864053",
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_pair(src, trg):\n",
    "    with torch.inference_mode():\n",
    "        alti = compute_alti_nllb(hub, src, trg)\n",
    "    scores = compute_alti_metrics(*alti)\n",
    "    return scores['avg_sc']\n",
    "\n",
    "print(score_pair('hallo Welt', 'hello world'))\n",
    "print(score_pair('hallo Welt', 'halo over my head'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe313bef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# score_pair(src, trg) forwards its arguments to compute_alti_nllb on the de2en model, so the\n",
    "# source text must come first and the hypothesis second (the two were previously swapped).\n",
    "hypotheses_scores['ALTI_avg_sc'] = {\n",
    "    k: [[score_pair(smpl.iloc[i].src, x) for x in hyps] for i, hyps in enumerate(vs)]\n",
    "    for k, vs in tqdm(smpl_diverse.items(), total=len(smpl_diverse))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a36cccec",
   "metadata": {},
   "outputs": [],
   "source": [
    "hub.to('cpu');\n",
    "cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b866ee4",
   "metadata": {},
   "source": [
    "### By XNLI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8569208",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "mname = 'joeddav/xlm-roberta-large-xnli'\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(mname).cuda()\n",
    "tokenizer = AutoTokenizer.from_pretrained(mname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22e78f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_clf_scores(texts1, texts2, batch_size=32, label='entailment', verbose=True):\n",
    "    \"\"\"Return the classifier probability of `label` for each pair (texts1[i], texts2[i]).\n",
    "\n",
    "    Pairs are tokenized together and run through `model` in batches of `batch_size`;\n",
    "    a progress bar is shown when `verbose` is true. Returns a 1-D numpy array with one\n",
    "    probability per pair, which is empty when `texts1` is empty.\n",
    "    \"\"\"\n",
    "    if len(texts1) == 0:\n",
    "        # np.concatenate raises ValueError on an empty list of batches\n",
    "        return np.zeros(0, dtype=np.float32)\n",
    "    scores = []\n",
    "    t = trange if verbose else range\n",
    "    for i in t(0, len(texts1), batch_size):\n",
    "        xx, yy = texts1[i:i+batch_size], texts2[i:i+batch_size]\n",
    "        with torch.inference_mode():\n",
    "            inputs = tokenizer(xx, yy, truncation=True, padding=True, return_tensors='pt').to(model.device)\n",
    "            proba = torch.softmax(model(**inputs).logits, -1)[:, model.config.label2id[label]].cpu().numpy()\n",
    "        scores.append(proba)\n",
    "    return np.concatenate(scores)\n",
    "\n",
    "def get_nli_scores(texts1, texts2, verbose=True):\n",
    "    \"\"\"Return the product of the entailment probabilities computed in both argument orders.\"\"\"\n",
    "    return get_clf_scores(texts1, texts2, verbose=verbose) * get_clf_scores(texts2, texts1, verbose=verbose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21c9455",
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_pair(src, trg):\n",
    "    return get_nli_scores([src], [trg], verbose=False)[0]\n",
    "\n",
    "print(score_pair('hallo Welt', 'hello world'))\n",
    "print(score_pair('hallo Welt', 'halo over my head'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97f75cce",
   "metadata": {},
   "outputs": [],
   "source": [
    "hypotheses_scores['XNLI'] = {\n",
    "    k: [[score_pair(x, smpl.iloc[i].src) for x in hyps] for i, hyps in enumerate(vs)]\n",
    "    for k, vs in tqdm(smpl_diverse.items(), total=len(smpl_diverse))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f795037a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.to('cpu')\n",
    "cleanup();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "816460ca",
   "metadata": {},
   "source": [
    "### By ref-ChrF++ (oracle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5693f2bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "chrfpp = CHRF(word_order=2)\n",
    "\n",
    "hypotheses_scores['ref_chrf'] = {\n",
    "    k: [[chrfpp.sentence_score(x, [smpl.iloc[i].ref]).score for x in hyps] for i, hyps in enumerate(vs)]\n",
    "    for k, vs in tqdm(smpl_diverse.items(), total=len(smpl_diverse))\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d533c8f",
   "metadata": {},
   "source": [
    "### Now compute the selections based on the hypotheses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc0f795",
   "metadata": {},
   "outputs": [],
   "source": [
    "selections = {\n",
    "    score_method: {\n",
    "        gen_method: [\n",
    "            hyps[np.argmax(hypotheses_scores[score_method][gen_method][i])]\n",
    "            for i, hyps in enumerate(hyps_list)\n",
    "        ]\n",
    "        for gen_method, hyps_list in smpl_diverse.items()\n",
    "    } \n",
    "    for score_method in hypotheses_scores\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3af77d72",
   "metadata": {},
   "source": [
    "### The reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cf58654",
   "metadata": {},
   "outputs": [],
   "source": [
    "selections['ref'] = {\n",
    "    k: smpl.ref.tolist()\n",
    "    for k, vs in tqdm(smpl_diverse.items(), total=len(smpl_diverse))\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1092424a",
   "metadata": {},
   "source": [
    "### By default "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fcd233f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every generation method, keep the first hypothesis of each candidate list.\n",
    "selections['first'] = {\n",
    "    gen_method: [hyps[0] for hyps in hyps_list]\n",
    "    for gen_method, hyps_list in tqdm(smpl_diverse.items(), total=len(smpl_diverse))\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "933d629c",
   "metadata": {},
   "source": [
    "The baseline (default translation) corresponds to taking the first hypothesis from beam search."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "089cdb6e",
   "metadata": {},
   "source": [
    "# Evaluate the selections"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18fd342e",
   "metadata": {},
   "source": [
    "### src-NLI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72d3aeb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "mname = 'joeddav/xlm-roberta-large-xnli'\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(mname).cuda()\n",
    "tokenizer = AutoTokenizer.from_pretrained(mname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c191595c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_clf_scores(texts1, texts2, batch_size=32, label='entailment', verbose=True):\n",
    "    \"\"\"Return the classifier probability of `label` for each pair (texts1[i], texts2[i]).\n",
    "\n",
    "    Pairs are tokenized together and run through `model` in batches of `batch_size`;\n",
    "    a progress bar is shown when `verbose` is true. Returns a 1-D numpy array with one\n",
    "    probability per pair, which is empty when `texts1` is empty.\n",
    "    \"\"\"\n",
    "    if len(texts1) == 0:\n",
    "        # np.concatenate raises ValueError on an empty list of batches\n",
    "        return np.zeros(0, dtype=np.float32)\n",
    "    scores = []\n",
    "    t = trange if verbose else range\n",
    "    for i in t(0, len(texts1), batch_size):\n",
    "        xx, yy = texts1[i:i+batch_size], texts2[i:i+batch_size]\n",
    "        with torch.inference_mode():\n",
    "            inputs = tokenizer(xx, yy, truncation=True, padding=True, return_tensors='pt').to(model.device)\n",
    "            proba = torch.softmax(model(**inputs).logits, -1)[:, model.config.label2id[label]].cpu().numpy()\n",
    "        scores.append(proba)\n",
    "    return np.concatenate(scores)\n",
    "\n",
    "def get_nli_scores(texts1, texts2, verbose=True):\n",
    "    \"\"\"Return the product of the entailment probabilities computed in both argument orders.\"\"\"\n",
    "    return get_clf_scores(texts1, texts2, verbose=verbose) * get_clf_scores(texts2, texts1, verbose=verbose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19a64b51",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sel_src_nli_raw = {\n",
    "    selector: {\n",
    "        sampler: get_nli_scores(sampled, smpl.src.tolist(), verbose=False).tolist()#.mean()\n",
    "        for sampler, sampled in by_sampler.items()\n",
    "    }\n",
    "    for selector, by_sampler in tqdm(selections.items())\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a42c543",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_src_nli = {k1: {k2: np.mean(v2) for k2, v2 in v1.items()} for k1, v1 in sel_src_nli_raw.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7977605",
   "metadata": {},
   "outputs": [],
   "source": [
    "bl_nli = get_nli_scores(smpl.mt.tolist(), smpl.src.tolist())\n",
    "bl_nli.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e57d0f3",
   "metadata": {},
   "source": [
    "### src-ref-COMET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba98b88b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = comet.download_model(\"wmt20-comet-da\")\n",
    "model = comet.load_from_checkpoint(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f571e691",
   "metadata": {},
   "outputs": [],
   "source": [
    "smpl_diverse.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5295403a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_for_comet = pd.DataFrame([\n",
    "    {'mt': hyp, 'src': smpl.src.iloc[i], 'ref': smpl.ref.iloc[i]}\n",
    "    for gen, by_gen in smpl_diverse.items()\n",
    "    for i, hyps in enumerate(by_gen)\n",
    "    for hyp in hyps\n",
    "] + [\n",
    "    {'mt': mt, 'src': smpl.src.iloc[i], 'ref': smpl.ref.iloc[i]}\n",
    "    for selector, by_sampler in selections.items()\n",
    "    for sampler, sampled in by_sampler.items()\n",
    "    for i, mt in enumerate(sampled)\n",
    "])\n",
    "print(data_for_comet.shape)\n",
    "data_for_comet = data_for_comet.drop_duplicates().reset_index(drop=True)\n",
    "print(data_for_comet.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49031c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "seg_scores_comet_ref, sys_score_comet_ref = model.predict(\n",
    "    data_for_comet.to_dict('records'), batch_size=32, gpus=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d37301e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup from a (src, mt, ref) triple to its segment-level COMET score; the score list is\n",
    "# indexed with the row labels of data_for_comet.\n",
    "texts2comet = {\n",
    "    (row.src, row.mt, row.ref): seg_scores_comet_ref[i]\n",
    "    for i, row in data_for_comet.iterrows()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b7de843",
   "metadata": {},
   "source": [
    "Add scores for all the hypotheses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dc52834",
   "metadata": {},
   "outputs": [],
   "source": [
    "hypotheses_scores['COMET'] = {\n",
    "    gen_method: [\n",
    "        [texts2comet[(smpl.src.iloc[i], hyp, smpl.ref.iloc[i])] for hyp in hyps]\n",
    "        for i, hyps in enumerate(by_gen)\n",
    "    ]\n",
    "    for gen_method, by_gen in smpl_diverse.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5c67c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "for selector, by_sampler in selections.items():\n",
    "    for sampler, sampled in by_sampler.items():\n",
    "        for i, mt in enumerate(sampled):\n",
    "            _ = texts2comet[(smpl.src.iloc[i], mt, smpl.ref.iloc[i])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f04937ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_comet_raw = {\n",
    "    selector: {\n",
    "        sampler: [\n",
    "            texts2comet[(smpl.src.iloc[i], mt, smpl.ref.iloc[i])] \n",
    "            for i, mt in enumerate(sampled)\n",
    "        ]\n",
    "        for sampler, sampled in by_sampler.items()\n",
    "    }\n",
    "    for selector, by_sampler in selections.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d04ab5e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_comet = {k1: {k2: np.mean(v2) for k2, v2 in v1.items()} for k1, v1 in sel_comet_raw.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7282e676",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_comet = {k1: {k2: np.mean(v2) for k2, v2 in v1.items()} for k1, v1 in sel_comet_raw.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc0939c8",
   "metadata": {},
   "source": [
    "# Save the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00c04650",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('../computed_data', exist_ok=True)\n",
    "# ensure_ascii=False writes non-ASCII characters verbatim, so the file encoding is pinned to\n",
    "# UTF-8 rather than left to the platform default.\n",
    "with open('../computed_data/diverse-decoding-results-more-hypotheses.json', 'w', encoding='utf-8') as f:\n",
    "    json.dump({\n",
    "        'data': smpl.to_dict(orient='records'),\n",
    "        'candidates': smpl_diverse,\n",
    "        # every score is cast to a plain float before it is serialized\n",
    "        'candidate_scores': {\n",
    "            k1: {k2: [[float(h) for h in hl] for hl in v2] for k2, v2 in v1.items()}\n",
    "            for k1, v1 in hypotheses_scores.items()\n",
    "        },\n",
    "        'selections': selections,\n",
    "        'nli_scores': sel_src_nli_raw, \n",
    "        'comet_scores': sel_comet_raw\n",
    "    }, f, ensure_ascii=False, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d29074",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
